# TODO:
# * Reorganize. Why are markers configured per labeled table? Shouldn't there be a default setting?

from __future__ import division


from core.itertools2 import column

import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt


class Attribute(object):
    """Drawing options for a single plotted series.

    Starts with a circle marker (``'o'``), a ``'solid'`` line style and no
    explicit color (``None``).
    """

    def __init__(self):
        self.marker, self.style, self.color = 'o', 'solid', None


class Legend(object):
    """Draw an axes legend whose entries can be put in a caller-chosen order.

    Labels named via :meth:`set_order` come first, in that order; every other
    labeled artist found on the axes follows afterwards.
    """

    def __init__(self, ax):
        # ax: the Axes whose labeled artists make up the legend.
        self.ax = ax
        self._orders = []

    def plot(self):
        """Attach the legend to ``self.ax``; do nothing if there are no labels.

        Labels in the preferred order that are not present on the axes, or
        that are listed more than once, are skipped instead of raising
        ``KeyError``.
        """
        handles, labels = self.ax.get_legend_handles_labels()
        remaining = dict(zip(labels, handles))

        ordered = []
        for label in self._orders:
            if label in remaining:
                # pop() so the entry is not appended again below.
                ordered.append((label, remaining.pop(label)))
        ordered.extend(remaining.items())

        if ordered:
            legend_handles = [handle for _, handle in ordered]
            legend_labels = [label for label, _ in ordered]
            self.ax.legend(legend_handles, legend_labels, loc=0)

    def set_order(self, *orders):
        """Set the labels that should appear first in the legend, in order."""
        self._orders = orders


class Plot(object):
    """ Wrapper around matplotlib plotting that can easily plot
    `class:DataTable` objects.  Generally I find matplotlib to have an awkward
    API which makes interaction difficult.  This abstracts many of the common
    operations.

    Tables are passed either positionally (each is labeled by its index) or
    as keyword arguments (each is labeled by its keyword), but not both.
    Nothing is drawn until :meth:`plot` (or :meth:`save`) is called, so
    ``xlim``, ``ylim``, ``grid`` and the ``set_*`` options can be adjusted
    beforehand.
    """

    def __init__(self, *tables, **named_tables):
        if len(named_tables) > 0 and len(tables) > 0:
            raise ValueError('Must provide only one of either named or unnamed'
                + ' tables')

        self.plotted = False
        self.fig = plt.figure()
        self.ax = self.fig.add_subplot(111)
        self.plt = plt

        self.xlim = None
        self.ylim = None

        if tables:
            named_tables = list(enumerate(tables))
        else:
            # list() so the pairs are indexable on Python 3, where
            # dict.items() returns a non-subscriptable view.
            named_tables = list(named_tables.items())

        # List of (label, table) pairs, drawn in this order.
        self.named_tables = named_tables

        # Default axis labels come from the first table's column names
        # (column 0 -> x axis, column 1 -> y axis), when a table exists.
        if self.named_tables:
            column_names = self.named_tables[0][1].column_names
            self.ax.set_xlabel(column_names[0])
            if len(column_names) > 1:
                self.ax.set_ylabel(column_names[1])

        self._attributes = {}  # label -> Attribute overrides
        self._legend = None
        self.yerr = None
        self.grid = False

    def errorbar(self, values):
        """Set the ``yerr`` values passed to ``Axes.errorbar`` for every table."""
        self.yerr = values

    def _line_kwargs(self, label):
        """Build the ``Axes.plot`` keyword arguments for the series ``label``."""
        attr = self._attributes.get(label, Attribute())
        kws = {'label': label}
        if attr.marker:
            kws['marker'] = attr.marker
        if attr.color:
            kws['color'] = attr.color
        # A falsy style means "no connecting line": matplotlib's 'None'.
        kws['linestyle'] = attr.style if attr.style else 'None'
        return kws

    def plot(self):
        """Draw every table, then the legend and axis limits.

        Only the first call draws; later calls return immediately.
        """
        if self.plotted:
            return
        self.plotted = True

        if self.grid:
            self.ax.grid(True)

        for name, dt in self.named_tables:
            c = dt.columns()
            xs, ys = list(c[0]), list(c[1])

            # `is not None` so array-valued errors (e.g. numpy arrays) do not
            # trigger an ambiguous truth-value test.
            if self.yerr is not None:
                self.ax.errorbar(xs, ys, yerr=self.yerr)

            self.ax.plot(xs, ys, **self._line_kwargs(name))

        if self.named_tables:
            self.legend().plot()

        # Apply the limits to this plot's own axes. pyplot's plt.xlim/ylim
        # act on the *current* axes, which belong to whichever figure was
        # created most recently.
        if self.ylim is not None:
            self.ax.set_ylim(self.ylim)
        if self.xlim is not None:
            self.ax.set_xlim(self.xlim)

    def legend(self):
        """Return this plot's (lazily created) :class:`Legend`."""
        if not self._legend:
            self._legend = Legend(self.ax)

        return self._legend

    def set_color(self, key, color):
        """Set the line color used for the series labeled ``key``."""
        self._attributes.setdefault(key, Attribute()).color = color

    def set_style(self, key, style):
        """Set the line style used for the series labeled ``key``."""
        self._attributes.setdefault(key, Attribute()).style = style

    def set_marker(self, key, marker):
        """Set the marker used for the series labeled ``key``."""
        self._attributes.setdefault(key, Attribute()).marker = marker

    @property
    def title(self):
        """The axes title text."""
        return self.ax.get_title()

    @title.setter
    def title(self, title):
        self.ax.set_title(title)

    @property
    def ylabel(self):
        """The y-axis label text."""
        return self.ax.get_ylabel()

    @ylabel.setter
    def ylabel(self, label):
        self.ax.set_ylabel(label)

    @property
    def xlabel(self):
        """The x-axis label text."""
        return self.ax.get_xlabel()

    @xlabel.setter
    def xlabel(self, label):
        self.ax.set_xlabel(label)

    def save(self, filename):
        """Draw the plot if needed and write it to ``<filename>.pdf``."""
        self.plot()
        # Save this plot's own figure, not pyplot's current figure.
        self.fig.savefig('%s.pdf' % filename, format='pdf')